{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## PaliGemma Fine-tuning\n",
        "\n",
        "In this notebook, we will fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) on a small split of the [VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2) dataset. Let's get started by installing the necessary libraries."
      ],
      "metadata": {
        "id": "m8t6tkjuuONX"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FrKEBkmJtMan"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate peft bitsandbytes"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
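        "PaliGemma is a gated model, so we authenticate with a Hugging Face access token using `notebook_login()`. The token is also needed to push the fine-tuned model to the Hub at the end."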
      ],
      "metadata": {
        "id": "q_85okyYt1eo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from huggingface_hub import notebook_login\n",
        "notebook_login()"
      ],
      "metadata": {
        "id": "NzJZSHD8tZZy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's load the dataset."
      ],
      "metadata": {
        "id": "9_jUBDTEuw1j"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "ds = load_dataset('HuggingFaceM4/VQAv2', split=\"train[:10%]\")\n"
      ],
      "metadata": {
        "id": "az5kdSbNpjgH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "cols_remove = [\"question_type\", \"answers\", \"answer_type\", \"image_id\", \"question_id\"]\n",
        "ds = ds.remove_columns(cols_remove)"
      ],
      "metadata": {
        "id": "GEsDnBNmppIJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Split off 5% of the 10% slice loaded above; for this demo we train on that small \"test\" part\n",
        "split_ds = ds.train_test_split(test_size=0.05)\n",
        "train_ds = split_ds[\"test\"]"
      ],
      "metadata": {
        "id": "wN1c9Aqhqt47"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_ds"
      ],
      "metadata": {
        "id": "TNJW2ty4yy4L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load the processor to preprocess the dataset."
      ],
      "metadata": {
        "id": "OsquATWQu2lJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import PaliGemmaProcessor\n",
        "model_id = \"google/paligemma-3b-pt-224\"\n",
        "processor = PaliGemmaProcessor.from_pretrained(model_id)"
      ],
      "metadata": {
        "id": "Zya_PWM3uBWs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
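        "Next, we write a collate function that preprocesses each batch. The prompt template prefixes each question with `answer `, and we pass the prompts, the images and the ground-truth answers (as `suffix`) to the processor. Because a `suffix` is given, the processor appends each answer to its prompt and also returns a `labels` tensor in which the image tokens, the prompt tokens and the padding are set to -100, so the model ignores them in the loss. Training on these labels teaches the model to generate the answers."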
      ],
      "metadata": {
        "id": "QZROnV-pu7rt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "device = \"cuda\"\n",
        "\n",
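        "# Build model inputs and labels for a batch of VQAv2 examples\n",
        "def collate_fn(examples):\n",
        "  # Prompt: task prefix followed by the question\n",
        "  texts = [\"answer \" + example[\"question\"] for example in examples]\n",
        "  # Target text: the most common human answer\n",
        "  labels = [example[\"multiple_choice_answer\"] for example in examples]\n",
        "  images = [example[\"image\"].convert(\"RGB\") for example in examples]\n",
        "  # Passing `suffix` makes the processor return `labels` with prefix and padding masked to -100\n",
        "  tokens = processor(text=texts, images=images, suffix=labels,\n",
        "                    return_tensors=\"pt\", padding=\"longest\",\n",
        "                    tokenize_newline_separately=False)\n",
        "\n",
        "  # BatchFeature.to(dtype) casts only floating-point tensors (pixel_values); token ids stay integers\n",
        "  tokens = tokens.to(torch.bfloat16).to(device)\n",
        "  return tokens\n"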
      ],
      "metadata": {
        "id": "hdw3uBcNuGmw"
      },
      "execution_count": null,
      "outputs": []
    },
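    {
      "cell_type": "markdown",
      "source": [
        "To make the label masking concrete, the next cell is a plain-Python sketch of it. It assumes, as in recent `transformers` releases, that when `suffix` is passed the processor marks image and prompt positions with `token_type_ids == 0` and answer positions with `1`, and builds `labels` by replacing every 0-type position with -100. The token ids are made up, and the sequence layout is illustrative only."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "IGNORE_INDEX = -100  # labels with this value are skipped by the cross-entropy loss\n",
        "\n",
        "def mask_labels(input_ids, token_type_ids):\n",
        "    # Keep a token id only where token_type_ids == 1 (the suffix); mask everything else\n",
        "    return [tok if tt == 1 else IGNORE_INDEX for tok, tt in zip(input_ids, token_type_ids)]\n",
        "\n",
        "# Hypothetical toy ids: 3 image tokens, 2 prompt tokens, 2 answer tokens (answer + eos), 1 padding token\n",
        "input_ids = [900, 900, 900, 11, 12, 21, 1, 0]\n",
        "token_type_ids = [0, 0, 0, 0, 0, 1, 1, 0]\n",
        "\n",
        "toy_labels = mask_labels(input_ids, token_type_ids)\n",
        "print(toy_labels)  # [-100, -100, -100, -100, -100, 21, 1, -100]"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },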
    {
      "cell_type": "markdown",
      "source": [
        "VQAv2 is a general-purpose dataset, similar to many of the datasets PaliGemma was pre-trained on. We therefore do not need to fine-tune the image encoder (`vision_tower`) or the multimodal projector: we freeze both and fine-tune only the text decoder."
      ],
      "metadata": {
        "id": "Hi_Y1blXwA04"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import PaliGemmaForConditionalGeneration\n",
        "import torch\n",
        "\n",
        "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)\n",
        "\n",
        "for param in model.vision_tower.parameters():\n",
        "    param.requires_grad = False\n",
        "\n",
        "for param in model.multi_modal_projector.parameters():\n",
        "    param.requires_grad = False\n"
      ],
      "metadata": {
        "id": "iZRvrfUquH1y"
      },
      "execution_count": null,
      "outputs": []
    },
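    {
      "cell_type": "markdown",
      "source": [
        "Optionally, check that freezing worked. The helper below is a hypothetical convenience function (not part of `transformers`) that sums `numel()` over `model.parameters()`. After running the cell above, `trainable, total = count_parameters(model)` should report fewer trainable than total parameters, the difference being the frozen vision tower and projector."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def count_parameters(model):\n",
        "    # Returns (trainable, total) parameter counts for a torch.nn.Module\n",
        "    params = list(model.parameters())\n",
        "    trainable = sum(p.numel() for p in params if p.requires_grad)\n",
        "    total = sum(p.numel() for p in params)\n",
        "    return trainable, total"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },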
    {
      "cell_type": "markdown",
      "source": [
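        "Alternatively, to fine-tune with LoRA or QLoRA, run the cell below instead of the model-loading cell above. As written, it loads the base model in 4-bit (QLoRA) and attaches LoRA adapters; for plain LoRA, drop `quantization_config` and load the model with `torch_dtype=torch.bfloat16` instead."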
      ],
      "metadata": {
        "id": "uCiVI-xUwSJm"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import BitsAndBytesConfig\n",
        "from peft import get_peft_model, LoraConfig\n",
        "\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "        load_in_4bit=True,\n",
        "        bnb_4bit_quant_type=\"nf4\",\n",
        "        bnb_4bit_compute_dtype=torch.bfloat16\n",
        ")\n",
        "\n",
        "lora_config = LoraConfig(\n",
        "    r=8,\n",
        "    target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
        "    task_type=\"CAUSAL_LM\",\n",
        ")\n",
        "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
        "model = get_peft_model(model, lora_config)\n",
        "model.print_trainable_parameters()\n",
        "#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344\n"
      ],
      "metadata": {
        "id": "9AYeuyzNuJ9X"
      },
      "execution_count": null,
      "outputs": []
    },
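    {
      "cell_type": "markdown",
      "source": [
        "The printed trainable-parameter count can be checked by hand. LoRA adds two matrices of shapes `r x d_in` and `d_out x r` to each targeted linear layer, i.e. `r * (d_in + d_out)` parameters. The sketch below assumes the Gemma-2B decoder dimensions (hidden size 2048, MLP size 16384, 18 layers, 8 query heads and 1 key/value head of size 256) and the SigLIP vision tower dimensions (hidden size 1152, 27 layers), where only the vision `q_proj`, `k_proj` and `v_proj` layers match `target_modules`. Under these assumptions it gives 11,298,816, the same number as the `print_trainable_parameters()` output noted in the cell above."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "def lora_params(d_in, d_out, r=8):\n",
        "    # A is (r x d_in) and B is (d_out x r), so LoRA adds r * (d_in + d_out) parameters per layer\n",
        "    return r * (d_in + d_out)\n",
        "\n",
        "# Assumed Gemma-2B text decoder dimensions\n",
        "hidden, mlp, head_dim, n_heads, n_kv_heads, n_layers = 2048, 16384, 256, 8, 1, 18\n",
        "decoder_layer = (\n",
        "    lora_params(hidden, n_heads * head_dim)       # q_proj\n",
        "    + lora_params(hidden, n_kv_heads * head_dim)  # k_proj\n",
        "    + lora_params(hidden, n_kv_heads * head_dim)  # v_proj\n",
        "    + lora_params(n_heads * head_dim, hidden)     # o_proj\n",
        "    + lora_params(hidden, mlp)                    # gate_proj\n",
        "    + lora_params(hidden, mlp)                    # up_proj\n",
        "    + lora_params(mlp, hidden)                    # down_proj\n",
        ")\n",
        "decoder_total = n_layers * decoder_layer\n",
        "\n",
        "# Assumed SigLIP vision tower dimensions; only q_proj, k_proj and v_proj match target_modules there\n",
        "vision_hidden, vision_layers = 1152, 27\n",
        "vision_total = vision_layers * 3 * lora_params(vision_hidden, vision_hidden)\n",
        "\n",
        "total = decoder_total + vision_total\n",
        "print(f'{total:,}')  # 11,298,816"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },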
    {
      "cell_type": "markdown",
      "source": [
        "We will now initialize the `TrainingArguments`."
      ],
      "metadata": {
        "id": "logv0oLqwbIe"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import TrainingArguments\n",
        "args=TrainingArguments(\n",
        "            num_train_epochs=2,\n",
        "            remove_unused_columns=False,\n",
        "            per_device_train_batch_size=4,\n",
        "            gradient_accumulation_steps=4,\n",
        "            warmup_steps=2,\n",
        "            learning_rate=2e-5,\n",
        "            weight_decay=1e-6,\n",
        "            adam_beta2=0.999,\n",
        "            logging_steps=100,\n",
        "            optim=\"adamw_torch\",\n",
        "            save_strategy=\"steps\",\n",
        "            save_steps=1000,\n",
        "            push_to_hub=True,\n",
        "            save_total_limit=1,\n",
        "            output_dir=\"paligemma_vqav2\",\n",
        "            bf16=True,\n",
        "            report_to=[\"tensorboard\"],\n",
        "            dataloader_pin_memory=False\n",
        "        )\n"
      ],
      "metadata": {
        "id": "Il7zKQO9uMPT"
      },
      "execution_count": null,
      "outputs": []
    },
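    {
      "cell_type": "markdown",
      "source": [
        "Before training, it can help to estimate how many optimizer updates these arguments produce. Each update consumes `per_device_train_batch_size * gradient_accumulation_steps` examples per device (4 x 4 = 16 here). The sketch below uses a hypothetical dataset size of 2,000 examples (replace it with `len(train_ds)`) and assumes a trailing partial accumulation window still counts as an update; the exact rounding may differ between `transformers` versions. If `len(train_ds)` is in the low thousands, `logging_steps=100` produces only a few log lines and `save_steps=1000` is never reached, so it is the final `trainer.push_to_hub()` call that saves and uploads the trained weights."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "\n",
        "def estimate_optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps, num_epochs, num_devices=1):\n",
        "    # Examples consumed by one optimizer update\n",
        "    effective_batch_size = per_device_batch_size * num_devices * grad_accum_steps\n",
        "    # Forward/backward micro-batches per epoch (the last one may be smaller)\n",
        "    micro_batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))\n",
        "    # Assumption: a trailing partial accumulation window still triggers an update\n",
        "    updates_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum_steps)\n",
        "    return effective_batch_size, updates_per_epoch * num_epochs\n",
        "\n",
        "num_examples = 2000  # hypothetical; use len(train_ds) in this notebook\n",
        "effective_bs, total_steps = estimate_optimizer_steps(num_examples, per_device_batch_size=4, grad_accum_steps=4, num_epochs=2)\n",
        "print(effective_bs, total_steps)  # 16 250"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },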
    {
      "cell_type": "markdown",
      "source": [
        "We can now start training."
      ],
      "metadata": {
        "id": "8pR0EaGlwrDp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import Trainer\n",
        "\n",
        "trainer = Trainer(\n",
        "        model=model,\n",
        "        train_dataset=train_ds,\n",
        "        data_collator=collate_fn,\n",
        "        args=args\n",
        "        )\n"
      ],
      "metadata": {
        "id": "CguCGDv1uNkF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.train()"
      ],
      "metadata": {
        "id": "9KFPQLrnF2Ha"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.push_to_hub()"
      ],
      "metadata": {
        "id": "O9fMDEjXSSzF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "You can find the inference steps in [this notebook](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing)."
      ],
      "metadata": {
        "id": "JohfxEJQjLBd"
      }
    }
  ]
}